from TransientImage import TransientImage

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
from mpl_toolkits.axes_grid1 import make_axes_locatable

class HistogramPlot:
	"""Interactive plot of the per-bin average of a TransientImage.

	The figure plots ``sum(image.data, axis=(0, 1))`` divided by
	``image.uResolution * image.vResolution`` against the index of the third
	data axis ("bins").
	A mirrored top x-axis spans ``[image.tMin, image.tMax]`` ("units").
	A "Log / Lin" button toggles the y-axis between linear and symmetric-log
	scaling.
	"""

	image = None #type: TransientImage
	name = None #type: str
	# Global minimum / maximum of image.data.
	vMin = None #type: float
	vMax = None #type: float
	figure = None
	# Every non-zero sample of image.data, flattened and sorted ascending.
	nonzeroDataSorted = None #type: np.array
	# True while the y-axis is in symlog mode.
	logScale = False #type: bool

	def __init__(self, image : TransientImage, name : str="") -> None:
		"""Build the histogram figure and its "Log / Lin" toggle button.

		:param image: source image; ``image.data`` is indexed as a 3-D array
			(the code reads ``data.shape[2]`` and sums over axes 0 and 1).
		:param name: appended to the figure title ``"Histogram " + name``.
		"""
		self.image = image
		self.name = name
		self.logScale = False

		data = self.image.data
		self.vMin = np.min(data)
		self.vMax = np.max(data)

		flat = data.flatten()
		self.nonzeroDataSorted = np.sort(flat[np.nonzero(flat)])

		numberOfPixels = self.image.uResolution * self.image.vResolution
		numberOfBins = data.shape[2]

		# plt.figure with a title reuses an existing figure of the same name.
		self.figure = plt.figure("Histogram " + self.name)
		ax1 = self.figure.gca()
		ax1.plot(np.arange(numberOfBins), np.sum(data, axis=(0, 1)) / numberOfPixels)
		ax1.set_xlabel("bins")
		ax1.set_xlim([0, numberOfBins - 1])
		ax1.grid()

		# Twin x-axis sharing the y-axis: same curve, labelled in time units.
		ax2 = ax1.twiny()
		ax2.set_xlabel("units")
		ax2.set_xlim([self.image.tMin, self.image.tMax])
		self.figure.tight_layout()

		self.timeSliceAx = ax1
		buttonAxes = self.figure.add_axes([0.8, 0.025, 0.1, 0.04])
		self.buttonLog = Button(buttonAxes, "Log / Lin")
		self.buttonLog.on_clicked(self.LogButton)

	def LogButton(self, event):
		"""Toggle the curve's y-axis between linear and symlog scaling.

		Used as the Button ``on_clicked`` callback; ``event`` is ignored.
		"""
		self.logScale = not self.logScale

		if self.logScale:
			linthresh = self._symlogThreshold()
			if linthresh is None:
				# All-zero image: nothing sensible to derive, use matplotlib's default.
				self.timeSliceAx.set_yscale("symlog")
			else:
				# matplotlib >= 3.3 names this keyword ``linthresh``. The old
				# ``linthreshy`` was deprecated in 3.3 and removed in 3.5.
				self.timeSliceAx.set_yscale("symlog", linthresh=linthresh)
		else:
			self.timeSliceAx.set_yscale("linear")
		self.figure.canvas.draw_idle()

	def _symlogThreshold(self):
		"""Return the smallest non-zero magnitude in the image, or None if there is none."""
		if self.nonzeroDataSorted is None or self.nonzeroDataSorted.size == 0:
			return None
		# linthresh must be positive. nonzeroDataSorted[0] would be the most
		# negative sample when the data contains negatives, so take magnitudes.
		return float(np.min(np.abs(self.nonzeroDataSorted)))